# -*- coding: utf-8 -*- #
"""
Time                2023/08/24 15:41
Author:             mingfeng (SunnyQjm)
Email               mfeng@linux.alibaba.com
File                crud.py
Description:        CRUD helpers for AlertData records (create, query, update, merge).
"""
from typing import List, Tuple, Optional
from clogger import logger
from sqlalchemy.orm import Session, joinedload
from app import models, schemas, query
from sysom_utils import dict_merge


class CrudException(Exception):
    """Raised when a CRUD helper in this module cannot complete its operation,
    e.g. when the AlertData record it has to modify does not exist."""


################################################################################################
# AlertData
################################################################################################


def create_alert_data(
    db: Session, alert_data: schemas.AlertDataCreate
) -> models.AlertData:
    """Insert a new AlertData row built from a create schema.

    Args:
        db (Session): Active database session; it is committed by this call.
        alert_data (schemas.AlertDataCreate): Payload whose ``.dict()`` items
            are passed as keyword arguments to ``models.AlertData``.

    Returns:
        models.AlertData: The persisted instance, refreshed after the commit
            so that values generated by the database are loaded.
    """
    fields = alert_data.dict()
    new_row = models.AlertData(**fields)
    db.add(new_row)
    db.commit()
    # Reload attributes from the DB after commit (commit expires them).
    db.refresh(new_row)
    return new_row


def get_alert_data_by_id(db: Session, id_: int) -> Optional[models.AlertData]:
    """Look up an AlertData row by its primary key.

    Args:
        db (Session): Active database session.
        id_ (int): Primary key of the row to fetch.

    Returns:
        Optional[models.AlertData]: The matching row, or None when no row has
            that primary key.
    """
    # Must pass the ``id_`` parameter: the bare name ``id`` is Python's builtin
    # function, not the caller-supplied key.
    return db.query(models.AlertData).get(id_)


def get_alert_data_by_alert_id(
    db: Session, alert_id: str
) -> Optional[models.AlertData]:
    """Return the first AlertData whose ``alert_id`` column equals ``alert_id``.

    Args:
        db (Session): Active database session.
        alert_id (str): Alert identifier to match.

    Returns:
        Optional[models.AlertData]: The first matching row, or None if none.
    """
    by_alert_id = models.AlertData.alert_id == alert_id
    matches = db.query(models.AlertData).filter(by_alert_id)
    return matches.first()


def update_or_create_alert_data(
    db: Session, alert_data: schemas.AlertDataCreate
) -> models.AlertData:
    """Upsert an AlertData keyed by ``alert_data.alert_id``.

    If a row with that alert_id already exists, only its ``status`` is
    overwritten from the payload and the session is committed; the payload's
    other fields are ignored. Otherwise a new row is created from the full
    payload via ``create_alert_data``.

    Args:
        db (Session): Active database session; it is committed by this call.
        alert_data (schemas.AlertDataCreate): Incoming alert payload.

    Returns:
        models.AlertData: The updated or newly created row.
    """
    existing = get_alert_data_by_alert_id(db, alert_data.alert_id)
    if existing is not None:
        existing.status = alert_data.status
        db.commit()
        return existing
    return create_alert_data(db, alert_data)


def change_alert_deal_status(
    db: Session, ids: List[int], alert_deal_status
) -> int:
    """Set ``deal_status`` on every AlertData whose primary key is in ``ids``.

    Issues a single bulk UPDATE through ``Query.update`` and then commits.

    Args:
        db (Session): Active database session; it is committed by this call.
        ids (List[int]): Primary keys of the rows to update.
        alert_deal_status: Value written to the ``deal_status`` column.

    Returns:
        int: Number of rows matched by the UPDATE, as reported by
            ``Query.update`` (not a model instance).
    """
    res = (
        db.query(models.AlertData)
        .filter(models.AlertData.id.in_(ids))
        .update({"deal_status": alert_deal_status})
    )
    db.commit()
    return res


def change_alert_deal_status_by_alert_id(
    db: Session, alert_ids: List[str], alert_deal_status
) -> int:
    """Set ``deal_status`` on every AlertData whose ``alert_id`` is in ``alert_ids``.

    Issues a single bulk UPDATE through ``Query.update`` and then commits.

    Args:
        db (Session): Active database session; it is committed by this call.
        alert_ids (List[str]): ``alert_id`` values of the rows to update.
        alert_deal_status: Value written to the ``deal_status`` column.

    Returns:
        int: Number of rows matched by the UPDATE, as reported by
            ``Query.update`` (not a model instance).
    """
    res = (
        db.query(models.AlertData)
        .filter(models.AlertData.alert_id.in_(alert_ids))
        .update({"deal_status": alert_deal_status})
    )
    db.commit()
    return res


def append_alert_annotations(
    db, alert_id: str, extra_annotations: dict
) -> models.AlertData:
    """Merge ``extra_annotations`` into an alert's annotations and commit.

    A fresh dict is built by applying ``dict_merge`` first with the stored
    annotations and then with ``extra_annotations``. The result is assigned
    back to the attribute instead of mutating the stored value in place.
    NOTE(review): presumably this is so SQLAlchemy notices the change on a
    plain JSON column; confirm against the model definition.

    Args:
        db: Active database session; it is committed by this call.
        alert_id (str): ``alert_id`` of the alert to annotate.
        extra_annotations (dict): Annotations to merge in.

    Returns:
        models.AlertData: The updated alert row.

    Raises:
        CrudException: If no AlertData exists for ``alert_id``.
    """
    target = get_alert_data_by_alert_id(db, alert_id)
    if target is None:
        raise CrudException(f"Alert data not exists for alert_id = {alert_id}")
    merged: dict = {}
    for source in (target.annotations, extra_annotations):
        dict_merge(merged, source)
    target.annotations = merged
    db.commit()
    return target


def merge_alert_data(
    db: Session, merge_list: List[str], new_alert_data: dict
) -> models.AlertData:
    """Create a new AlertData that absorbs the alerts listed in ``merge_list``.

    Every existing AlertData whose ``alert_id`` is in ``merge_list`` gets
    ``deal_status`` set to 2 and is appended to the new row's
    ``merged_alerts`` collection. The new row is then added, committed and
    refreshed.

    Args:
        db (Session): Active database session; it is committed by this call.
        merge_list (List[str]): ``alert_id`` values of the alerts to absorb.
        new_alert_data (dict): Keyword arguments for the new ``models.AlertData``.

    Returns:
        models.AlertData: The newly created (merged) alert row.
    """
    merged_into = models.AlertData(**new_alert_data)
    absorbed_query = db.query(models.AlertData).filter(
        models.AlertData.alert_id.in_(merge_list)
    )
    for absorbed in absorbed_query.all():
        # NOTE(review): 2 presumably marks the alert as "merged"; confirm
        # against the deal_status definitions.
        absorbed.deal_status = 2
        merged_into.merged_alerts.append(absorbed)
    db.add(merged_into)
    db.commit()
    db.refresh(merged_into)
    return merged_into


def get_alert_datas(
    db: Session, query_params: query.AlertDataQueryParams
) -> Tuple[List[models.AlertData], int]:
    """Return one page of AlertData rows together with the total match count.

    Args:
        db (Session): Active database session.
        query_params (query.AlertDataQueryParams): Filter, sort and paging
            parameters.

    Returns:
        Tuple[List[models.AlertData], int]: ``(page_rows, total_count)``.
            ``total_count`` comes from ``query_params.get_count_by`` over
            ``AlertData.id`` with the filters applied, not just the page size.
            Rows on the page have ``merged_alerts`` eagerly loaded with a
            joined load.
    """
    # Total number of rows matching the filters (ignores paging).
    total_count = query_params.get_count_by(db, models.AlertData.id)

    # The requested page: filter, then sort, then paginate.
    builder = query_params.get_query_builder(db)
    builder = builder.apply_filter().apply_sorter().apply_paging()
    page_query = builder.build().options(
        joinedload(models.AlertData.merged_alerts)
    )
    alert_list = page_query.all()

    return alert_list, total_count
